import java.lang.foreign.*;
import java.lang.invoke.*;
import java.util.function.IntBinaryOperator;
import static java.lang.foreign.ValueLayout.*;

// java --enable-preview --enable-native-access=ALL-UNNAMED -cp out TestFFI
// java --enable-preview --enable-native-access=ALL-UNNAMED -XX:+UnlockExperimentalVMOptions -XX:+UseSystemMemoryBarrier -cp out TestFFI
// java --enable-preview --enable-native-access=ALL-UNNAMED -XX:+UnlockExperimentalVMOptions -XX:+UseSystemMemoryBarrier -XX:+UseEpsilonGC -cp out TestFFI
/**
 * Micro-benchmark comparing JNI native calls with Foreign Function &amp; Memory (FFM) API
 * downcalls into the same native library {@code add}.
 *
 * <p>Two scenarios are timed: a plain {@code add(a, b)} call, and an {@code add_back(a, b, cb)}
 * call that is handed a native function pointer ({@link #ms_add}) wrapping a Java method.
 * The native library must export the FFM symbols {@code add} and {@code add_back}, plus the JNI
 * entry points for {@link #nativeAdd} (and {@link #nativeAddBack} if that benchmark is enabled).
 */
public class TestFFI {
	/** JNI binding, resolved against the library loaded by {@code System.loadLibrary("add")}. */
	static native int nativeAdd(int a, int b);

	/**
	 * JNI binding that receives a Java callback. Only referenced by the disabled benchmark in
	 * {@link #testJNI()}, hence the {@code unused} suppression.
	 */
	@SuppressWarnings("unused")
	static native int nativeAddBack(int a, int b, IntBinaryOperator cb);

	/** Downcall handle for the native function {@code int add(int, int)}. */
	static final MethodHandle mh_add;
	/** Downcall handle for the native function {@code int add_back(int, int, <address>)}. */
	static final MethodHandle mh_add_back;
	/** Upcall stub exposing {@link #add(int, int)} as a native function pointer for {@code add_back}. */
	static final MemorySegment ms_add;

	static {
		System.loadLibrary("add"); // needed by the JNI methods; add.dll / libadd.so / libadd.dylib
		// loaderLookup() resolves symbols in libraries loaded via System.loadLibrary by this
		// class's loader. The FFM handles therefore bind to exactly the library the JNI methods
		// use, on every platform. A platform-specific file name passed to libraryLookup, such as
		// "add.dll", would only work on Windows and re-open the library via the OS search path.
		var libAdd = SymbolLookup.loaderLookup();
		var linker = Linker.nativeLinker();
		mh_add = linker.downcallHandle(findSymbol(libAdd, "add"),
				FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_INT));
		mh_add_back = linker.downcallHandle(findSymbol(libAdd, "add_back"),
				FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_INT, ADDRESS));
		try {
			var local_add = MethodHandles.lookup().findStatic(TestFFI.class, "add",
					MethodType.methodType(int.class, int.class, int.class));
			// Global arena: the stub stays valid for the lifetime of the JVM.
			ms_add = linker.upcallStub(local_add,
					FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_INT), Arena.global());
		} catch (ReflectiveOperationException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Looks up a native symbol, failing with a descriptive error instead of a bare
	 * {@code NoSuchElementException} when the library does not export it.
	 *
	 * @param lookup the symbol lookup to search
	 * @param name   the native symbol name
	 * @return the address of the symbol
	 * @throws UnsatisfiedLinkError if {@code lookup} has no symbol called {@code name}
	 */
	private static MemorySegment findSymbol(SymbolLookup lookup, String name) {
		return lookup.find(name).orElseThrow(
				() -> new UnsatisfiedLinkError("native symbol not found in library 'add': " + name));
	}

	/** Java implementation handed to native code through the {@link #ms_add} upcall stub. */
	static int add(int a, int b) {
		return a + b;
	}

	/** Number of calls made in each timed loop. */
	static final int LOOP_COUNT = 100_000_000;

	/**
	 * Times {@link #LOOP_COUNT} JNI calls and prints the accumulated result and elapsed ms.
	 * The sum is allowed to overflow; it is printed so the loop's work is observable.
	 */
	static void testJNI() {
		long t = System.nanoTime();
		int n = 0;
		for (int i = 0; i < LOOP_COUNT; i++)
			n += nativeAdd(n, i);
		t = (System.nanoTime() - t) / 1_000_000;
		System.out.println("testJNI_Direct  : " + n + ", " + t + " ms");

		// Disabled JNI-callback variant (a fresh lambda per call; noted as slow).
//		t = System.nanoTime();
//		n = 0;
//		for (int i = 0; i < LOOP_COUNT; i++)
//			n += nativeAddBack(n, i, (a, b) -> a + b); // slow
//		t = (System.nanoTime() - t) / 1_000_000;
//		System.out.println("testJNI_CallBack: " + n + ", " + t + " ms");
	}

	/**
	 * Times {@link #LOOP_COUNT} FFM downcalls to {@code add}, then {@link #LOOP_COUNT} downcalls
	 * to {@code add_back} with the {@link #ms_add} upcall stub, printing result and elapsed ms.
	 *
	 * @throws Throwable propagated from {@link MethodHandle#invokeExact}
	 */
	static void testFFI() throws Throwable {
		long t = System.nanoTime();
		int n = 0;
		// invokeExact requires the call-site type to match exactly, hence the (int) casts.
		for (int i = 0; i < LOOP_COUNT; i++)
			n += (int)mh_add.invokeExact(n, i);
		t = (System.nanoTime() - t) / 1_000_000;
		System.out.println("testFFI_Direct  : " + n + ", " + t + " ms");

		t = System.nanoTime();
		n = 0;
		for (int i = 0; i < LOOP_COUNT; i++)
			n += (int)mh_add_back.invokeExact(n, i, ms_add);
		t = (System.nanoTime() - t) / 1_000_000;
		System.out.println("testFFI_CallBack: " + n + ", " + t + " ms");
	}

	/** Runs both benchmarks for 5 rounds (presumably so early rounds act as JIT warm-up). */
	public static void main(String[] args) throws Throwable {
		for (int i = 0; i < 5; i++) {
			System.out.println("-------- " + i);
			testJNI();
			testFFI();
		}
	}

//	public static void main(String[] args) {
//		System.out.println(nativeAddBack(1, 2, (a, b) -> a + b));
//		System.out.println(nativeAddBack(1, 2, (a, b) -> a - b));
//	}
}

// https://developer.okta.com/blog/2022/04/08/state-of-ffi-java
